import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from scipy.io import loadmat
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTModel, ViTConfig
from pycocotools.coco import COCO
from PIL import Image
import os


def train(model, dataloader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module being optimized; it is switched to training mode here.
        dataloader: iterable yielding ``(images, labels)`` batches. It does not
            need to define ``__len__``; batches are counted as they are consumed.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        device: device both tensors of each batch are moved to.

    Returns:
        The arithmetic mean of ``loss.item()`` over all batches of the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches, so no mean exists.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Divide by the batches actually seen rather than len(dataloader): this works
    # for iterables without __len__ and turns an empty epoch into a clear error.
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute mean loss")
    return total_loss / num_batches


class MultiLabelImageDataset(Dataset):
    """Image dataset whose targets are multi-hot label vectors read from a CSV.

    The first CSV column holds image file names, relative to ``img_dir``. Every
    remaining column is treated as one class label.
    """

    def __init__(self, img_dir, labels_file, transform=None):
        """
        Args:
            img_dir: directory the file names in the first CSV column are joined to.
            labels_file: path of the CSV, parsed with ``pandas.read_csv``.
            transform: optional callable applied to the RGB ``PIL.Image``.
        """
        self.img_dir = img_dir
        self.labels = pd.read_csv(labels_file)
        self.transform = transform
        # Every column after the image-name column is a class.
        self.num_classes = len(self.labels.columns) - 1

    def __len__(self):
        """Return the number of CSV rows, one sample per row."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for row ``idx``.

        ``labels`` is a float32 tensor holding the row's ``num_classes`` label
        columns. ``image`` is the RGB image, passed through ``transform`` if one
        was given.
        """
        img_name = self.labels.iloc[idx, 0]
        img_path = os.path.join(self.img_dir, img_name)
        # Open inside a context manager so the underlying file handle is released
        # deterministically. convert() returns a new, fully loaded image, so
        # ``image`` stays valid after the source is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        labels = torch.tensor(
            self.labels.iloc[idx, 1:].values.astype(float), dtype=torch.float32
        )
        return image, labels


class ViTClassifier(nn.Module):
    """Linear multi-label head on top of a pre-saved backbone module.

    The backbone is a whole pickled ``nn.Module`` loaded with ``torch.load``. It
    must expose ``config.hidden_size``. When called as ``backbone(pixel_values=x)``,
    it must return an object with a 3-D ``last_hidden_state`` tensor.
    """

    def __init__(self, num_labels, weights_path='./data/model/clip/vitB/vit_ini.pt'):
        """
        Args:
            num_labels: number of output logits (one per class).
            weights_path: file holding the pickled backbone module. The default
                is the path this class has always loaded from.
        """
        super().__init__()
        # The checkpoint is a full pickled module, not a state_dict, so
        # weights_only=False is required; PyTorch >= 2.6 defaults to True and
        # would refuse to unpickle it. Unpickling can execute arbitrary code:
        # only load checkpoints from trusted sources.
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
        # callers are expected to move the model with .to(device) afterwards.
        self.vit = torch.load(weights_path, map_location='cpu', weights_only=False)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)

    def forward(self, x):
        """Return raw logits of shape ``(B, num_labels)``.

        ``B`` is the leading dimension of the backbone's ``last_hidden_state``.
        """
        outputs = self.vit(pixel_values=x)
        # Position 0 of the token axis is used as the image summary
        # (presumably the [CLS] token of the ViT backbone; confirm for this checkpoint).
        cls_output = outputs.last_hidden_state[:, 0, :]
        return self.classifier(cls_output)


TRAIN_DIR = './data/dataset/coco-dataset-for-multi-label-image-classification/imgs/imgs/train'
TEST_DIR = './data/dataset/coco-dataset-for-multi-label-image-classification/imgs/imgs/test'
LABELS_FILE = './data/dataset/coco-dataset-for-multi-label-image-classification/labels/labels/labels_train.csv'
CATEGORIES_FILE = './data/dataset/coco-dataset-for-multi-label-image-classification/labels/labels/categories.csv'
CANARY_FILE = './data/mat/clipmem/coco/canarylist.mat'
MODEL_SAVE_PATH = './data/model/clip/mineclip/vitB/trained/100_sl_01_f.pt'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 1e-3
num_epochs = 100


def load_canary_indices(mat_path, key='clist'):
    """Return variable ``key`` of the .mat file at ``mat_path`` as a flat list of ints.

    ``scipy.io.loadmat`` returns every array as at least 2-D, e.g. shape (1, N)
    or (N, 1), and MATLAB numerics are commonly stored as double. A plain
    ``.tolist()`` would therefore give nested (float) lists, which
    ``torch.utils.data.Subset`` would pass straight into the dataset as bogus
    indices. Flattening and casting yields plain integer row indices.

    NOTE(review): if the list was produced in MATLAB it is probably 1-based;
    confirm and subtract 1 if so.
    """
    indices = loadmat(mat_path)[key]
    return indices.ravel().astype(int).tolist()


if __name__ == "__main__":
    # DataLoader(num_workers > 0) re-imports this module in every worker on
    # platforms using the "spawn" start method (Windows, macOS). Without this
    # guard each worker would re-execute the whole script and fail.
    # Mean/std below are the ImageNet statistics. NOTE(review): the backbone path
    # mentions CLIP; confirm these match the backbone's pretraining normalization.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = MultiLabelImageDataset(TRAIN_DIR, LABELS_FILE, transform=transform)
    canarylist = load_canary_indices(CANARY_FILE)
    canaryset = torch.utils.data.Subset(train_dataset, canarylist)
    num_classes = train_dataset.num_classes
    train_loader = DataLoader(canaryset, batch_size=128, shuffle=True, num_workers=4)

    model = ViTClassifier(num_labels=num_classes).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Create the output directory up front so a missing folder fails fast
    # instead of after the whole training run.
    os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

    for epoch in range(num_epochs):
        loss = train(model, train_loader, optimizer, criterion, device)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.4f}')
    # Saves the whole pickled module (not a state_dict), matching how the
    # backbone checkpoint is stored.
    torch.save(model, MODEL_SAVE_PATH)
